This demo provides an overview of applying Random Survival Forests to simulated data. We split the demo into the following steps.
Dataset simulation:
We simulate covariates and a right-censored survival outcome using a proportional hazards model with a time-constant baseline hazard:
\(Surv(time,status) \sim HCT+BPSYS+trt+trt:BMI\)
where HCT, BPSYS, and BMI are baseline covariates, trt is the treatment arm, and trt:BMI is the treatment-by-BMI interaction.
The dataset is deliberately small so that the knitting runtime stays short. Still, some chunks take longer to run; we set those chunks to eval=FALSE and saved their outcomes in .qs files, which are read back in the next chunk. If you want to run these chunks, set eval to TRUE.
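As a rough sketch of how such data can be generated (not the exact simulation code behind this demo; the coefficients, baseline hazard, and distributions below are illustrative), note that under a proportional hazards model with constant baseline hazard \(\lambda_0\), survival times can be drawn by inverse-transform sampling: \(T = -\log(U)/(\lambda_0 e^{lp})\) for \(U \sim Uniform(0,1)\), with administrative censoring applied afterwards:

```r
# Sketch: simulate right-censored survival times under a proportional
# hazards model with a constant baseline hazard (all numbers illustrative).
set.seed(123)
n <- 289
HCT   <- rnorm(n, mean = 41, sd = 5)    # hematocrit
BPSYS <- rnorm(n, mean = 118, sd = 14)  # systolic blood pressure
BMI   <- rnorm(n, mean = 28, sd = 5)
trt   <- rbinom(n, size = 1, prob = 0.5)

lambda0 <- 0.005  # constant baseline hazard (made up here)
lp <- -0.05 * HCT + 0.02 * BPSYS - 0.5 * trt + 0.03 * trt * BMI
t_event <- -log(runif(n)) / (lambda0 * exp(lp))  # inverse-transform sampling

cens    <- 200                          # administrative censoring at t = 200
.time   <- pmin(t_event, cens)
.status <- as.integer(t_event <= cens)  # 1 = event observed, 0 = censored
```

The censoring time of 200 matches the maximum of .time seen in the data summary below; everything else in the sketch is an assumption.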
# load the tidyverse
library(tidyverse)
#load utility packages
library(kableExtra)
library(here)
library(tictoc)
library(stringr)
library(qs)
library(skimr)
library(ggplot2)
data <- qs::qread(here::here("Data", "Demo_data_tte.qs"))
data %>% skimr::skim()
| Name | Piped data |
| Number of rows | 289 |
| Number of columns | 38 |
| _______________________ | |
| Column type frequency: | |
| factor | 8 |
| numeric | 30 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| .trt | 0 | 1.00 | FALSE | 2 | TRT: 153, PLC: 136 |
| SEX | 0 | 1.00 | FALSE | 2 | M: 148, F: 141 |
| RACE | 20 | 0.93 | FALSE | 3 | WHI: 199, ASI: 40, BLA: 30 |
| atrial_fibrillation | 0 | 1.00 | FALSE | 2 | yes: 176, no: 113 |
| myocardial_infarction | 0 | 1.00 | FALSE | 2 | no: 219, yes: 70 |
| coronary_artery_disease | 0 | 1.00 | FALSE | 2 | no: 243, yes: 46 |
| ventricular_tachycardia | 0 | 1.00 | FALSE | 2 | no: 278, yes: 11 |
| angina_pectoris | 0 | 1.00 | FALSE | 2 | no: 279, yes: 10 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| .id | 0 | 1.00 | 10159.92 | 93.55 | 10001.00 | 10080.00 | 10159.00 | 10241.00 | 10320.00 | ▇▇▇▇▇ |
| .time | 0 | 1.00 | 126.96 | 85.66 | 0.00 | 24.54 | 200.00 | 200.00 | 200.00 | ▅▁▁▁▇ |
| .status | 0 | 1.00 | 0.28 | 0.45 | 0.00 | 0.00 | 0.00 | 1.00 | 1.00 | ▇▁▁▁▃ |
| AGE | 0 | 1.00 | 67.31 | 13.52 | 45.00 | 57.00 | 67.00 | 80.00 | 92.00 | ▇▇▇▆▆ |
| CALCIUM | 3 | 0.99 | 9.50 | 0.43 | 8.17 | 9.21 | 9.51 | 9.80 | 10.74 | ▁▅▇▆▁ |
| CREAT | 4 | 0.99 | 1.23 | 0.33 | 0.65 | 1.00 | 1.19 | 1.40 | 2.54 | ▅▇▃▁▁ |
| GGT | 2 | 0.99 | 59.41 | 70.68 | 1.81 | 21.46 | 38.30 | 70.65 | 775.85 | ▇▁▁▁▁ |
| HB | 4 | 0.99 | 13.59 | 1.58 | 10.21 | 12.43 | 13.48 | 14.49 | 19.82 | ▃▇▅▂▁ |
| HCT | 3 | 0.99 | 41.14 | 4.72 | 28.87 | 37.87 | 40.72 | 44.50 | 56.67 | ▂▇▇▃▁ |
| HDL | 2 | 0.99 | 47.57 | 15.08 | 18.61 | 36.58 | 45.77 | 54.98 | 105.61 | ▃▇▃▁▁ |
| LDL | 1 | 1.00 | 86.48 | 34.45 | 27.35 | 61.72 | 82.29 | 103.90 | 255.75 | ▆▇▂▁▁ |
| MAGNES | 1 | 1.00 | 2.19 | 0.28 | 1.46 | 1.98 | 2.17 | 2.38 | 2.98 | ▂▆▇▅▁ |
| POTASS | 3 | 0.99 | 4.37 | 0.45 | 3.26 | 4.05 | 4.37 | 4.66 | 5.76 | ▂▆▇▂▁ |
| SODIUM | 3 | 0.99 | 138.43 | 2.77 | 129.19 | 136.53 | 138.41 | 140.32 | 147.43 | ▁▃▇▃▁ |
| URICAC | 5 | 0.98 | 7.78 | 2.28 | 3.43 | 6.16 | 7.42 | 9.20 | 20.94 | ▆▇▂▁▁ |
| BMI | 2 | 0.99 | 27.84 | 4.82 | 16.70 | 24.20 | 28.00 | 31.05 | 42.10 | ▂▇▇▃▁ |
| BPDIA | 4 | 0.99 | 71.55 | 10.18 | 43.00 | 64.00 | 71.00 | 78.00 | 100.00 | ▁▅▇▅▁ |
| BPSYS | 2 | 0.99 | 117.78 | 14.45 | 83.00 | 107.50 | 118.00 | 128.00 | 162.00 | ▂▆▇▃▁ |
| HR | 3 | 0.99 | 68.96 | 11.01 | 39.00 | 62.00 | 69.00 | 77.00 | 98.00 | ▁▅▇▅▁ |
| WEIGHT | 2 | 0.99 | 82.23 | 17.51 | 36.10 | 68.75 | 81.50 | 96.10 | 126.20 | ▁▇▇▇▂ |
| noise1 | 0 | 1.00 | -0.03 | 0.97 | -2.76 | -0.69 | 0.03 | 0.69 | 2.17 | ▁▃▇▆▂ |
| noise2 | 0 | 1.00 | 0.06 | 1.01 | -2.75 | -0.53 | 0.03 | 0.81 | 3.20 | ▁▅▇▅▁ |
| noise3 | 0 | 1.00 | 0.00 | 1.01 | -2.68 | -0.64 | -0.02 | 0.66 | 3.54 | ▂▆▇▂▁ |
| noise4 | 0 | 1.00 | 0.02 | 0.98 | -3.70 | -0.63 | 0.05 | 0.63 | 2.81 | ▁▂▇▇▁ |
| noise5 | 0 | 1.00 | 0.03 | 1.00 | -3.02 | -0.63 | 0.07 | 0.71 | 2.86 | ▁▃▇▅▁ |
| noise6 | 0 | 1.00 | -0.05 | 0.99 | -3.04 | -0.66 | -0.07 | 0.65 | 3.20 | ▁▅▇▃▁ |
| noise7 | 0 | 1.00 | -0.04 | 1.03 | -3.63 | -0.76 | 0.00 | 0.59 | 2.87 | ▁▃▇▇▁ |
| noise8 | 0 | 1.00 | -0.07 | 1.00 | -2.53 | -0.74 | -0.06 | 0.57 | 2.82 | ▂▆▇▃▁ |
| noise9 | 0 | 1.00 | 0.05 | 1.00 | -2.91 | -0.56 | 0.01 | 0.70 | 2.59 | ▁▃▇▆▂ |
| noise10 | 0 | 1.00 | 0.05 | 1.03 | -2.81 | -0.61 | 0.13 | 0.75 | 3.06 | ▁▅▇▅▁ |
We start by loading the tidymodels metapackage and splitting our data into training and testing sets.
library(tidymodels)
set.seed(123)
#create a single binary split of the data into a training set and testing set
data_split <- rsample::initial_split(data, strata = .status)
#extract the resulting data
data_train <- rsample::training(data_split)
data_test <- rsample::testing(data_split)
We pre-process the training data and apply exactly the same steps to the test data.
#textrecipes contains extra steps for the recipes package for preprocessing text data.
library(textrecipes)
# make a recipe ####
tte_recipe <-
recipes::recipe(formula = .time + .status ~ ., data = data_train) %>%
recipes::update_role(.id, new_role = "id") %>%
recipes::update_role(c(.time, .status), new_role = "outcome") %>%
recipes::step_impute_knn(recipes::all_predictors(), -.trt) %>%
recipes::step_naomit(recipes::all_predictors()) %>%
recipes::step_nzv(recipes::all_predictors(),
freq_cut = 95 / 5,
unique_cut = 10) %>%
recipes::step_normalize(recipes::all_numeric_predictors()) %>%
recipes::step_corr(recipes::all_numeric_predictors(), threshold = 0.9)
#prepare new data####
prep_tte_recipe <- tte_recipe %>%
recipes::prep()
prep_data_test <-
recipes::bake(object = prep_tte_recipe, new_data = data_test)
prep_data_train <- recipes::juice(prep_tte_recipe)
# inspect data ####
data_prep <- prep_data_train %>%
bind_rows(prep_data_test)
data_prep %>%
skimr::skim()
| Name | Piped data |
| Number of rows | 289 |
| Number of columns | 36 |
| _______________________ | |
| Column type frequency: | |
| factor | 7 |
| numeric | 29 |
| ________________________ | |
| Group variables | None |
Variable type: factor
| skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
|---|---|---|---|---|---|
| .trt | 0 | 1 | FALSE | 2 | TRT: 153, PLC: 136 |
| SEX | 0 | 1 | FALSE | 2 | M: 148, F: 141 |
| RACE | 0 | 1 | FALSE | 3 | WHI: 217, ASI: 41, BLA: 31 |
| atrial_fibrillation | 0 | 1 | FALSE | 2 | yes: 176, no: 113 |
| myocardial_infarction | 0 | 1 | FALSE | 2 | no: 219, yes: 70 |
| coronary_artery_disease | 0 | 1 | FALSE | 2 | no: 243, yes: 46 |
| ventricular_tachycardia | 0 | 1 | FALSE | 2 | no: 278, yes: 11 |
Variable type: numeric
| skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
|---|---|---|---|---|---|---|---|---|---|---|
| .id | 0 | 1 | 10159.92 | 93.55 | 10001.00 | 10080.00 | 10159.00 | 10241.00 | 10320.00 | ▇▇▇▇▇ |
| AGE | 0 | 1 | -0.03 | 1.02 | -1.71 | -0.81 | -0.06 | 0.92 | 1.82 | ▇▇▇▆▆ |
| CALCIUM | 0 | 1 | 0.00 | 1.01 | -3.16 | -0.66 | 0.04 | 0.71 | 2.96 | ▁▅▇▆▁ |
| CREAT | 0 | 1 | 0.03 | 1.06 | -1.88 | -0.72 | -0.10 | 0.58 | 4.31 | ▅▇▃▁▁ |
| GGT | 0 | 1 | 0.05 | 1.04 | -0.80 | -0.51 | -0.26 | 0.21 | 10.63 | ▇▁▁▁▁ |
| HCT | 0 | 1 | 0.00 | 0.96 | -2.50 | -0.67 | -0.09 | 0.65 | 3.16 | ▂▇▇▃▁ |
| HDL | 0 | 1 | -0.07 | 0.99 | -1.98 | -0.78 | -0.19 | 0.43 | 3.77 | ▃▇▃▁▁ |
| LDL | 0 | 1 | 0.00 | 1.04 | -1.79 | -0.74 | -0.12 | 0.53 | 5.13 | ▆▇▂▁▁ |
| MAGNES | 0 | 1 | 0.03 | 0.99 | -2.49 | -0.69 | -0.02 | 0.71 | 2.79 | ▂▆▇▅▁ |
| POTASS | 0 | 1 | -0.05 | 0.95 | -2.38 | -0.72 | -0.02 | 0.57 | 2.88 | ▂▆▇▂▁ |
| SODIUM | 0 | 1 | -0.01 | 1.04 | -3.51 | -0.72 | -0.02 | 0.70 | 3.39 | ▁▃▇▃▁ |
| URICAC | 0 | 1 | 0.01 | 0.99 | -1.89 | -0.69 | -0.14 | 0.62 | 5.74 | ▆▇▂▁▁ |
| BMI | 0 | 1 | 0.02 | 0.99 | -2.28 | -0.73 | 0.03 | 0.67 | 2.97 | ▂▇▇▃▁ |
| BPDIA | 0 | 1 | 0.06 | 1.02 | -2.81 | -0.70 | 0.00 | 0.70 | 2.91 | ▁▅▇▅▁ |
| BPSYS | 0 | 1 | 0.07 | 1.02 | -2.39 | -0.62 | 0.08 | 0.79 | 3.19 | ▂▆▇▃▁ |
| HR | 0 | 1 | 0.05 | 1.03 | -2.76 | -0.60 | 0.06 | 0.82 | 2.79 | ▁▅▇▅▁ |
| WEIGHT | 0 | 1 | 0.03 | 0.99 | -2.59 | -0.72 | -0.01 | 0.81 | 2.51 | ▁▇▇▇▂ |
| noise1 | 0 | 1 | -0.02 | 0.99 | -2.81 | -0.70 | 0.04 | 0.71 | 2.22 | ▁▃▇▆▂ |
| noise2 | 0 | 1 | -0.03 | 1.03 | -2.89 | -0.63 | -0.06 | 0.73 | 3.16 | ▁▅▇▅▁ |
| noise3 | 0 | 1 | 0.05 | 1.01 | -2.62 | -0.58 | 0.04 | 0.71 | 3.59 | ▂▆▇▂▁ |
| noise4 | 0 | 1 | 0.02 | 1.01 | -3.80 | -0.65 | 0.05 | 0.65 | 2.89 | ▁▂▇▇▁ |
| noise5 | 0 | 1 | -0.04 | 1.01 | -3.12 | -0.71 | 0.00 | 0.65 | 2.82 | ▁▃▇▅▁ |
| noise6 | 0 | 1 | -0.04 | 1.00 | -3.07 | -0.66 | -0.07 | 0.66 | 3.24 | ▁▅▇▃▁ |
| noise7 | 0 | 1 | -0.06 | 0.99 | -3.51 | -0.76 | -0.03 | 0.55 | 2.73 | ▁▃▇▇▁ |
| noise8 | 0 | 1 | 0.05 | 1.00 | -2.41 | -0.62 | 0.05 | 0.68 | 2.93 | ▂▆▇▃▁ |
| noise9 | 0 | 1 | -0.02 | 1.01 | -3.01 | -0.63 | -0.06 | 0.64 | 2.54 | ▁▃▇▆▂ |
| noise10 | 0 | 1 | -0.01 | 1.02 | -2.84 | -0.66 | 0.07 | 0.68 | 2.97 | ▁▅▇▅▁ |
| .time | 0 | 1 | 126.96 | 85.66 | 0.00 | 24.54 | 200.00 | 200.00 | 200.00 | ▅▁▁▁▇ |
| .status | 0 | 1 | 0.28 | 0.45 | 0.00 | 0.00 | 0.00 | 1.00 | 1.00 | ▇▁▁▁▃ |
library(mlr3learners) # extends mlr3 with popular learners; needed for ranger
library(mlr3proba) # supports survival analysis
library(mlr3) # core mlr3 package
# construction of Survival task ####
# First we put the data into a memory-efficient data.table backend
# create instance
data_use <-
mlr3::DataBackendDataTable$new(
data = prep_data_train %>%
dplyr::mutate(.id = as.integer(.id)) %>% data.table::as.data.table(),
primary_key = ".id"
)
# Specify the survival task, create new instance
surv_task <- mlr3proba::TaskSurv$new(
id = "surv_example",
backend = data_use,
time = ".time",
event = ".status",
type = c("right")
)
Kaplan-Meier curve
# Explore Kaplan-Meier curve
mlr3viz::autoplot(surv_task)
# build the learner
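As a cross-check (a sketch assuming the survival package, which is not otherwise used in this demo), the same Kaplan-Meier estimate can be computed directly from the prepared training data:

```r
library(survival)

# Kaplan-Meier estimate of the survival function on the training data
km_fit <- survfit(Surv(.time, .status) ~ 1, data = prep_data_train)
plot(km_fit,
     xlab = "time", ylab = "survival probability",
     conf.int = TRUE, mark.time = TRUE)  # ticks mark censoring times
```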
ranger_lrn <- mlr3::lrn(
"surv.ranger",
respect.unordered.factors = "order",
verbose = FALSE,
importance = "permutation"
) #Variable importance mode
# Inspect parameters
ranger_lrn$param_set
## <ParamSet>
## id class lower upper nlevels default
## 1: num.trees ParamInt 1 Inf Inf 500
## 2: mtry ParamInt 1 Inf Inf <NoDefault[3]>
## 3: importance ParamFct NA NA 4 <NoDefault[3]>
## 4: write.forest ParamLgl NA NA 2 TRUE
## 5: min.node.size ParamInt 1 Inf Inf 5
## 6: replace ParamLgl NA NA 2 TRUE
## 7: sample.fraction ParamDbl 0 1 Inf <NoDefault[3]>
## 8: splitrule ParamFct NA NA 4 logrank
## 9: num.random.splits ParamInt 1 Inf Inf 1
## 10: max.depth ParamInt -Inf Inf Inf
## 11: alpha ParamDbl -Inf Inf Inf 0.5
## 12: minprop ParamDbl -Inf Inf Inf 0.1
## 13: regularization.factor ParamUty NA NA Inf 1
## 14: regularization.usedepth ParamLgl NA NA 2 FALSE
## 15: seed ParamInt -Inf Inf Inf
## 16: split.select.weights ParamDbl 0 1 Inf <NoDefault[3]>
## 17: always.split.variables ParamUty NA NA Inf <NoDefault[3]>
## 18: respect.unordered.factors ParamFct NA NA 3 ignore
## 19: scale.permutation.importance ParamLgl NA NA 2 FALSE
## 20: keep.inbag ParamLgl NA NA 2 FALSE
## 21: holdout ParamLgl NA NA 2 FALSE
## 22: num.threads ParamInt 1 Inf Inf 1
## 23: save.memory ParamLgl NA NA 2 FALSE
## 24: verbose ParamLgl NA NA 2 TRUE
## 25: oob.error ParamLgl NA NA 2 TRUE
## id class lower upper nlevels default
## value
## 1:
## 2:
## 3: permutation
## 4:
## 5:
## 6:
## 7:
## 8:
## 9:
## 10:
## 11:
## 12:
## 13:
## 14:
## 15:
## 16:
## 17:
## 18: order
## 19:
## 20:
## 21:
## 22: 1
## 23:
## 24: FALSE
## 25:
## value
Now it’s time to tune!
We will tune the following random forest parameters: num.trees, mtry, and min.node.size.
We will use the mlr3 ecosystem to build a survival random forest and tune its hyperparameters.
library(paradox)
search_space <- paradox::ps(
num.trees = paradox::p_int(lower = 500, upper = 2000),
mtry = paradox::p_int(
lower = floor(length(surv_task$col_roles$feature) * 0.1),
upper = floor(length(surv_task$col_roles$feature) * 0.9)
),
min.node.size = paradox::p_int(lower = 1, upper = 40)
)
search_space
## <ParamSet>
## id class lower upper nlevels default value
## 1: num.trees ParamInt 500 2000 1501 <NoDefault[3]>
## 2: mtry ParamInt 3 29 27 <NoDefault[3]>
## 3: min.node.size ParamInt 1 40 40 <NoDefault[3]>
We need to specify how to evaluate the performance of a trained model. For this, we choose a resampling strategy and a performance measure: here, cross-validation and the C-index.
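To make the measure concrete: Harrell's C-index is the proportion of comparable pairs of subjects (where we can tell who had the event first) that the risk score ranks correctly. A minimal sketch on made-up toy data, using survival::concordance (the data and score below are invented for illustration):

```r
library(survival)

# Toy data: a risk score that is perfectly anti-correlated with event time
toy <- data.frame(
  time   = c(5, 10, 12, 20, 30, 35),
  status = c(1,  1,  0,  1,  0,  1),
  risk   = c(0.9, 0.8, 0.5, 0.4, 0.2, 0.1)
)

# Harrell's C: proportion of comparable pairs ranked concordantly.
# reverse = TRUE because a higher risk score should predict a shorter time.
concordance(Surv(time, status) ~ risk, data = toy, reverse = TRUE)$concordance
# here every comparable pair is concordant, so C = 1
```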
library(mlr3tuning)
#choose strategy and measure ####
#3-fold cross validation
hout <- mlr3::rsmp("cv", folds = 3)
measure <- mlr3::msr("surv.cindex")
#Terminator that stops after a number of evaluations
evals5 <- mlr3tuning::trm("evals", n_evals = 5)
# generate tuning instance from task, learner, search space, resampling method and measure
instance <- mlr3tuning::TuningInstanceSingleCrit$new(
task = surv_task,
learner = ranger_lrn,
resampling = hout,
measure = measure,
search_space = search_space,
terminator = evals5
)
tuner <- mlr3tuning::tnr("grid_search", resolution = 5)
Note 1: The actual_tuning chunk is set to eval=FALSE. If you want to run it, set eval=TRUE; if you want to save the outcome, uncomment the relevant line (see comments in the code).
Note 2: We use parallelization for training purposes, even though the dataset is small.
#packages needed for parallelization
library(doFuture)
library(doRNG)
library(foreach)
tictoc::tic()
# enable parallel processing
doFuture::registerDoFuture()
future::plan(future::multisession, workers = future::availableCores() - 1)
# specify seed
doRNG::registerDoRNG(seed = 123)
tuner$optimize(instance)
# disable parallel backend
foreach::registerDoSEQ()
tictoc::toc()
#Uncomment next line to save the outcome
# qs::qsave(instance, here::here("Data", "htune_demo.qs"))
How did all the possible parameter combinations do?
instance <- qs::qread(here::here("Data", "htune_demo.qs"))
hyparams <- instance$search_space$ids()
perf_data <- instance$archive$data
perf_data %>%
select(num.trees, mtry, min.node.size, surv.harrell_c) %>%
arrange(desc(surv.harrell_c)) %>%
mutate(surv.harrell_c = surv.harrell_c %>% round(., digits = 4)) %>%
kableExtra::kable(escape = FALSE) %>%
kableExtra::kable_styling(
bootstrap_options = "striped",
full_width = FALSE,
position = "left"
) %>%
kableExtra::column_spec(1, bold = TRUE) %>%
kableExtra::row_spec(0,
bold = TRUE,
background = "#00617F",
color = "white")| num.trees | mtry | min.node.size | surv.harrell_c |
|---|---|---|---|
| 1250 | 9 | 10 | 0.8597 |
| 1250 | 9 | 30 | 0.8459 |
| 1250 | 15 | 30 | 0.8386 |
| 875 | 28 | 30 | 0.8305 |
| 875 | 15 | 1 | 0.8264 |
Change hyperparameters to those selected in the tuning step
# adding best hyperparameters
ranger_lrn$param_set$values <- c(
ranger_lrn$param_set$values,
perf_data %>%
select(num.trees, mtry, min.node.size, surv.harrell_c) %>%
arrange(desc(surv.harrell_c)) %>%
select(-surv.harrell_c) %>%
slice(1)
)
Train the final learner
set.seed(1234)
final_rf <- ranger_lrn$train(task = surv_task)
Testing
# predict the outcome with the test data
pred_test <- final_rf$predict_newdata(newdata = prep_data_test)
# Define the performance metrics
pred_measures <- suppressWarnings(mlr3::msrs("surv.cindex"))
# Estimate performance
test_performance <- pred_test$score(
measures = pred_measures,
task = surv_task,
learner = final_rf,
train_set = surv_task$row_ids
) %>%
tibble::enframe(name = ".metric", value = ".estimate")
#print performance
test_performance %>%
kableExtra::kable(escape = FALSE) %>%
kableExtra::kable_styling(
bootstrap_options = "striped",
full_width = FALSE,
position = "left"
) %>%
kableExtra::column_spec(1, bold = TRUE) %>%
kableExtra::row_spec(0,
bold = TRUE,
background = "#00617F",
color = "white")
| .metric | .estimate |
|---|---|
| surv.harrell_c | 0.927572 |
Training
# predict the outcome with the training data
pred_train <- final_rf$predict_newdata(newdata = prep_data_train)
# Estimate performance
train_performance <- pred_train$score(
measures = pred_measures,
task = surv_task,
learner = final_rf,
train_set = surv_task$row_ids
) %>%
tibble::enframe(name = ".metric", value = ".estimate")
# Print performance
train_performance %>%
kableExtra::kable(escape = FALSE) %>%
kableExtra::kable_styling(
bootstrap_options = "striped",
full_width = FALSE,
position = "left"
) %>%
kableExtra::column_spec(1, bold = TRUE) %>%
kableExtra::row_spec(0,
bold = TRUE,
background = "#00617F",
color = "white")
| .metric | .estimate |
|---|---|
| surv.harrell_c | 0.9606475 |
Lastly, let’s look at feature importance for this model. For a ranger model, we need to set importance = "permutation" in the learner so that permutation-based feature importance is computed during training.
importance <- final_rf$importance() %>%
as_tibble_col() %>%
bind_cols(variables = final_rf$importance() %>% names()) %>%
relocate(variables)
importance %>%
DT::datatable(
rownames = TRUE,
filter = "top",
selection = "single",
extensions = c("Buttons"),
options = list(
lengthMenu = c(5, 10, 25, 50),
pageLength = 5,
scrollX = TRUE,
dom = "lfrtBpi",
buttons = list("excel")
)
)
# top 10
vi_nplot <- 10
#plot permutation importance
imp_fr_plt <- importance %>%
dplyr::arrange(., desc(value)) %>%
dplyr::slice(1:vi_nplot) %>%
dplyr::mutate(Sign = as.factor(ifelse(value > 0, "positive", "negative")))
p <- imp_fr_plt %>%
ggplot(aes(
y = reorder(variables, value),
x = value,
fill = Sign
)) +
geom_col() +
scale_fill_manual(values = c("#00659C", "#930A34")) +
theme(
legend.position = "none",
axis.title.y = element_blank(),
plot.subtitle = element_text(size = 11),
plot.title.position = "plot",
plot.margin = margin(r = 20)
) +
labs(subtitle = paste0("The top ", vi_nplot, " Variables based on Permutation"))
#plot output
p
We will show partial dependence plots, using all available data (train + test).
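The partial dependence of a model on a feature is obtained by fixing that feature at each grid value for every row, predicting, and averaging. In miniature (a generic sketch, not the exact code used below; predict_fun is a hypothetical placeholder returning one numeric prediction per row):

```r
# Sketch of the PDP recipe: fix a feature at each grid value,
# predict on the whole data set, and average the predictions.
pdp_1d <- function(data, feature, grid, predict_fun) {
  vapply(grid, function(g) {
    data[[feature]] <- g        # set the feature to the grid value for all rows
    mean(predict_fun(data))     # average prediction over the data
  }, numeric(1))
}
```

The code that follows applies the same idea per time point, averaging the predicted survival probabilities from the ranger model.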
library(tidyverse)
library(ranger)
rf_model <- final_rf$model
# define time points represented by rank/order they appear
time_points <-
seq(
from = 1,
to = length(rf_model$unique.death.times),
length.out = 10
) %>% round()
Note: The PDP predictions chunk is set to eval=FALSE. If you want to run it, set eval=TRUE; if you want to save the outcome, uncomment the relevant line (see comments in the code).
# define feature and grid (based on feature class)
feat <- "HCT"
feat_cat <- data_prep %>% dplyr::pull(!!feat) %>% class()
n_grid <- 50
if (feat_cat == "numeric") {
feat_range <- data_prep %>%
dplyr::pull(!!feat) %>%
range()
feat_grid <-
seq(from = feat_range[1],
to = feat_range[2],
length.out = n_grid)
} else {
feat_grid <- data_prep %>%
dplyr::pull(!!feat) %>%
levels()
}
# replace corresponding feature values with grid values
data_sets <- purrr::map(feat_grid,
~ data_prep %>% dplyr::mutate(dplyr::across(
tidyselect::all_of(feat),
.fns = function(y)
.x
)))
# calculate predictions for specified time points ####
preds <- map(
data_sets,
~ ranger:::predict.ranger(rf_model, data = .x)$survival %>%
tibble::as_tibble() %>%
dplyr::select(all_of(time_points))
)
# calculate PDPs (average (survival probability) for each feature grid value per timepoint)
pdp_data <- purrr::map2(
preds,
feat_grid,
~ .x %>%
apply(2, mean) %>%
tibble::enframe(value = ".value", name = "time_id") %>%
dplyr::mutate(time_id = stringr::str_replace(time_id, "V", "") %>% as.numeric()) %>%
dplyr::mutate(feat_val1 = .y)
) %>%
dplyr::bind_rows() %>%
dplyr::rename({{ feat }} := feat_val1)
#Uncomment next line to save the outcome
# qs::qsave(pdp_data, here::here("Data", "hct_pdp_preds.qs"))
Plot PDPs
feat <- "HCT"
pdp_data <- qs::qread(here::here("Data", "hct_pdp_preds.qs"))
feat_cat <- data_prep %>% dplyr::pull(!!feat) %>% class()
#create rug for HCT
hct_rug <- data_prep %>%
dplyr::pull(!!feat)
#plot
p <- pdp_data %>%
ggplot(aes(x = !!rlang::sym(feat), y = 1 - .value)) + # 1 minus for event probability, not survival prob
{
if (feat_cat == "numeric")
geom_line()
else
geom_col()
} +
facet_wrap(
~ time_id,
nrow = 2,
labeller = ggplot2::labeller(
time_id = function(s) {
rf_model$unique.death.times[as.numeric(s)] %>% round(3)
},
# construct time labels within the function
.default = ggplot2::label_value
)
) +
ylab("event probability")
p + ggplot2::geom_rug(
data = hct_rug %>%
tibble::enframe(),
mapping = ggplot2::aes(x = value),
inherit.aes = F,
sides = "b",
alpha = 1,
col = "#B3B3B3"
)
Calculate PDPs
# define feature and grid (based on feature class)
feat <- ".trt"
feat_cat <- data_prep %>% dplyr::pull(!!feat) %>% class()
n_grid <- 50
if (feat_cat == "numeric") {
feat_range <- data_prep %>%
dplyr::pull(!!feat) %>%
range()
feat_grid <-
seq(from = feat_range[1],
to = feat_range[2],
length.out = n_grid)
} else {
feat_grid <- data_prep %>%
dplyr::pull(!!feat) %>%
levels()
}
# replace corresponding feature values with grid values
data_sets <- purrr::map(feat_grid,
~ data_prep %>% dplyr::mutate(dplyr::across(
tidyselect::all_of(feat),
.fns = function(y)
.x
)))
# calculate predictions for specified time points
preds <- map(
data_sets,
~ ranger:::predict.ranger(rf_model, data = .x)$survival %>%
tibble::as_tibble() %>%
dplyr::select(all_of(time_points))
)
# calculate PDPs (average (survival probability) for each feature grid value per timepoint)
pdp_data <- purrr::map2(
preds,
feat_grid,
~ .x %>%
apply(2, mean) %>%
tibble::enframe(value = ".value", name = "time_id") %>%
dplyr::mutate(time_id = stringr::str_replace(time_id, "V", "") %>% as.numeric()) %>%
dplyr::mutate(feat_val1 = .y)
) %>%
dplyr::bind_rows() %>%
dplyr::rename({{ feat }} := feat_val1)
Plot PDPs
p<-pdp_data %>%
ggplot(aes(x = !!rlang::sym(feat), y = 1 - .value)) + # 1 minus for event probability, not survival prob
{ if (feat_cat == "numeric") geom_line() else geom_col()} +
facet_wrap(~ time_id,
nrow = 2,
labeller = ggplot2::labeller(time_id = function(s) {rf_model$unique.death.times[as.numeric(s)] %>% round(3) }, # construct time labels within the function
.default = ggplot2::label_value)) +
ylab("event probability")
p
Calculate 2D PDPs for BMI and .trt
Note: The PDP predictions chunk is set to eval=FALSE. If you want to run it, set eval=TRUE; if you want to save the outcome, uncomment the relevant line (see comments in the code).
# define grid ####
feat <- c("BMI", ".trt")
n_grid <- 50
#create range for BMI
bmi_range <- data_prep %>%
dplyr::pull(BMI) %>%
range()
bmi_grid <-
seq(from = bmi_range[1],
to = bmi_range[2],
length.out = n_grid)
#get trt levels
trt_grid <- data_prep %>%
dplyr::pull(.trt) %>%
levels()
# replace corresponding feature values with grid values
data_sets <-
tidyr::expand_grid(BMI = bmi_grid, .trt = trt_grid, data_prep %>% select(-c(BMI, .trt)))
# calculate predictions for specified time points
preds <-
ranger:::predict.ranger(rf_model, data = data_sets)$survival %>%
tibble::as_tibble() %>%
dplyr::select(all_of(time_points))
#merge predictions with data set and calculate 2D PDPs (average (survival probability) for each feature combination value per timepoint)
pdp_data <- data_sets %>% bind_cols(preds) %>%
filter(BMI %in% bmi_grid) %>%
pivot_longer(
c(colnames(preds)),
names_to = "time_id",
values_to = ".values"
) %>% mutate(time_id = stringr::str_replace(time_id, "V", "") %>% as.numeric()) %>%
group_by(BMI, .trt, time_id) %>% summarise(.value = mean(.values))
#Uncomment next line to save the outcome
# qs::qsave(pdp_data, here::here("Data", "bmi_trt_int_pdp_preds.qs"))
Plot PDPs
#rug
bmi_rug <- data_prep %>%
dplyr::pull(BMI)
pdp_data <- qs::qread(here::here("Data", "bmi_trt_int_pdp_preds.qs"))
#create range for BMI
bmi_range <- data_prep %>%
dplyr::pull(BMI) %>%
range()
#plot
p <- pdp_data %>%
ggplot2::ggplot(aes(
x = BMI,
y = 1 - .value,
color = .trt
)) + # 1 minus for event probability, not survival prob
geom_line() +
facet_wrap(
~ time_id,
nrow = 2,
labeller = ggplot2::labeller(
time_id = function(s) {
rf_model$unique.death.times[as.numeric(s)] %>% round(3)
},
# construct time labels within the function
.default = ggplot2::label_value
)
) +
ylab("event probability")
p + ggplot2::geom_rug(
data = bmi_rug %>% tibble::enframe(),
mapping = ggplot2::aes(x = value),
inherit.aes = F,
sides = "b",
alpha = 1,
col = "#B3B3B3"
) +
coord_cartesian(ylim = c(0, 0.5))
Session info
R version 4.0.1 (2020-06-06)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.6 LTS
Matrix products: default
BLAS: /opt/R/4.0.1/lib/R/lib/libRblas.so
LAPACK: /opt/R/4.0.1/lib/R/lib/libRlapack.so
attached base packages: stats, graphics, grDevices, utils, datasets, methods, base
other attached packages: ranger_0.13.1, mlr3tuning_0.13.0, paradox_0.9.0, mlr3proba_0.4.0, mlr3learners_0.4.5, mlr3_0.13.3, textrecipes_0.4.1, yardstick_0.0.8, workflowsets_0.0.2, workflows_0.2.3, tune_0.1.5, rsample_0.1.0, recipes_0.1.16, parsnip_0.2.1.9001, modeldata_0.1.1, infer_0.5.4, dials_0.0.9, scales_1.1.1, broom_0.7.9, tidymodels_0.1.3, skimr_2.1.4, qs_0.24.1, tictoc_1.0.1, here_1.0.1, kableExtra_1.3.4, forcats_0.5.1, stringr_1.4.0, dplyr_1.0.7, purrr_0.3.4, readr_1.4.0, tidyr_1.1.4, tibble_3.1.5, ggplot2_3.3.5, tidyverse_1.3.1
To cite R in publications use:
R Core Team (2020). R: A Language and Environment for Statistical Computing. R Foundation for Statistical Computing, Vienna, Austria. https://www.R-project.org/.
To cite the ggplot2 package in publications use:
Wickham H (2016). ggplot2: Elegant Graphics for Data Analysis. Springer-Verlag New York. ISBN 978-3-319-24277-4, https://ggplot2.tidyverse.org.
Program: /home/antigoni_elefsinioti/sdi_ml_course/Demo/Demo_simulated_tte.Rmd
HTML template by Sebastian Voss of Chrestos GmbH & Co. KG on behalf of Biomarker & Data Insights, Bayer AG in 2019